From the Supplementary Tables, export Sheet
Table S9-Gene Counts CPM to
Gene_Counts_CPM.csv and sheet
Table S10-Drug Responses to
Drug_Responses.csv. Place both .csv files in a folder with
the following path: datasets/beataml/real/.
You can also download the .csv files directly from our public link, https://cloud.cing.ac.cy/index.php/s/d9npoZDR7HebnSf
library(data.table)
library(magrittr)
# Read file ####
# check.names = TRUE converts sample-column names to syntactically valid
# R names (e.g. "14-00739" -> "X14.00739"), matching make.names() below.
cpm <- fread(file = "../../datasets/beataml/real/Gene_Counts_CPM.csv", check.names = TRUE)
# Drop the first two columns; the remaining columns are samples.
# NOTE(review): assumes columns 1-2 are gene annotation - confirm against the CSV.
cpm_mat <- as.matrix(cpm[, -c(1, 2)])
rownames(cpm_mat) <- cpm$Gene
# Transpose ####
# After transposing: rows = samples, columns = genes.
cpm_mat <- t(cpm_mat)
# Convert to data.table
# keep.rownames = TRUE stores the sample ids as the first column.
cpm_preprocess <- data.table(cpm_mat, keep.rownames = TRUE)
colnames(cpm_preprocess)[1] <- "labid"
# Read file ####
drug_response <- fread(file = "../../datasets/beataml/real/Drug_Responses.csv")
# make.names() so drug and sample ids match the cleaned expression column names.
drug_response$inhibitor <- make.names(drug_response$inhibitor)
drug_response$lab_id <- make.names(drug_response$lab_id)
# NOTE(review): assumes lab_id is column 2 of the drug-response table - confirm.
colnames(drug_response)[2] <- "labid"
cat("In total we have", length(unique(drug_response[,inhibitor])), "drugs\n")
## In total we have 122 drugs
In the data there are four drugs that are used in clinical practice: Gilteritinib, Lenalidomide, Midostaurin, and Venetoclax. We will use Venetoclax as an example and convert the problem to a classification task. For educational purposes, we use the median AUC value to define the drug response classes (Sensitive, Resistant).
# Number of samples available for each of the four clinically used drugs
clinical_drugs <- c("Venetoclax", "Midostaurin",
                    "Gilteritinib..ASP.2215.", "Lenalidomide")
drug_counts <- table(drug_response[inhibitor %in% clinical_drugs, inhibitor])
sort(drug_counts, decreasing = TRUE)
## .
## Midostaurin Venetoclax Gilteritinib..ASP.2215.
## 423 295 191
## Lenalidomide
## 177
# Select drug of interest
drug_j <- "Venetoclax"
drug_response_j <- drug_response[inhibitor == drug_j, c("labid", "auc")]
# Dichotomize AUC at the median (for educational purposes only):
# at or below the median -> "Sensitive"; above the median -> "Resistant".
median_auc <- median(drug_response_j$auc)
drug_response_j[, auc_binary := ifelse(auc <= median_auc,
                                       "Sensitive",
                                       "Resistant")]
# Merge data
# Inner join on sample id: keeps only samples that have both expression
# data and a Venetoclax response label.
data_all <- merge.data.table(
x = drug_response_j[, c("labid", "auc_binary")],
y = cpm_preprocess,
by = "labid")
head(data_all)[, 1:4]
## Key: <labid>
## labid auc_binary ENSG00000000003 ENSG00000000419
## <char> <char> <num> <num>
## 1: X14.00739 Sensitive 0.06837848 28.24031
## 2: X14.00781 Resistant 0.03291972 40.75461
## 3: X14.00787 Sensitive 0.00000000 39.25088
## 4: X14.00798 Resistant 0.00000000 32.71511
## 5: X14.00815 Resistant 1.35469640 37.53125
## 6: X14.00817 Resistant 0.13684772 16.83227
# Create X and y
# X: numeric gene-expression matrix (samples x genes); y: class labels.
X <- as.matrix(data_all[, -c("labid", "auc_binary")])
y <- data_all[, auc_binary]
Before proceeding with supervised model training, we explore the structure of the data using unsupervised learning techniques. These methods help uncover hidden patterns, detect outliers, and assess sample groupings without using the response variable.
# Per-gene standard deviation across all samples
gene_sd <- apply(X, MARGIN = 2, FUN = sd)
hist(gene_sd, xlab = "Genes standard deviation")
# Keep the 50 most variable genes
genes2keep <- names(head(sort(gene_sd, decreasing = TRUE), n = 50))
X_50 <- X[, colnames(X) %in% genes2keep]
We apply PCA to reduce dimensionality and visualize major sources of variance in the data.
# Standardize data (zero mean, unit variance per gene)
X_scaled <- scale(X_50)
# Apply PCA
pca <- prcomp(X_scaled)
# Proportion of total variance explained by each principal component
explained_var <- pca$sdev^2 / sum(pca$sdev^2)
# Scree plot: variance explained by each PC
# (seq_along() instead of 1:length() - safe for empty input)
scree_df <- data.frame(PC = seq_along(pca$sdev),
                       VarExplained = explained_var)
library(ggplot2)
ggplot(scree_df, aes(x = PC, y = VarExplained)) +
  geom_line() +
  geom_point() +
  labs(title = "Scree Plot",
       x = "Principal Component",
       y = "Proportion of variance explained") +
  theme_minimal()
# Visualize first two principal components, colored by the
# (held-out) drug-response label
pca_df <- data.frame(PC1 = pca$x[, 1], PC2 = pca$x[, 2], Response = y)
ggplot(pca_df, aes(PC1, PC2, color = Response)) +
  geom_point() +
  labs(title = "PCA of gene expression data") +
  theme_minimal()
We apply clustering algorithms (e.g., k-means and hierarchical clustering) to discover natural groupings in the samples.
# K-means with k = 2 on the scaled top-50 genes; many random starts
# (nstart = 1000) to avoid poor local optima.
kClust <- kmeans(scale(X_50), centers=2, nstart = 1000, iter.max = 2000)
kClusters <- as.character(kClust$cluster)
annotation_col <- data.frame(Response = y)
rownames(annotation_col) <- data_all$labid
rownames(X_50) = rownames(annotation_col)
pca_data <- as.data.frame(pca$x[, 1:2]) # Get the first two principal components
pca_data$Cluster <- as.factor(kClust$cluster)
True_Response <- y
#colnames(annotation_true)="Response"
# Plot
# Color = unsupervised cluster assignment; shape = true response label,
# so agreement between the two is visible at a glance.
ggplot(pca_data, aes(x = PC1, y = PC2, color = Cluster,shape = True_Response)) +
geom_point(size = 2) +
theme_minimal() +
labs(title = "K-means Clustering (k = 2) Visualized by PCA")
# Scale and transpose the data so samples are rows again (if they aren't)
X_scaled <- scale(X_50) # Genes as columns, samples as rows
# Create annotation for pheatmap
annotation_col <- data.frame(y)
colnames(annotation_col)="Response"
rownames(annotation_col) <- data_all$labid
rownames(X_50) = rownames(annotation_col)
library(pheatmap)
# Ward.D2 hierarchical clustering of both genes and samples;
# the column annotation bar shows the true response class.
pheatmap(t(X_scaled), # transpose so genes are rows, samples are columns
annotation_col = annotation_col,
show_rownames = TRUE,
show_colnames = TRUE,
clustering_distance_cols = "euclidean",
clustering_method = "ward.D2",
fontsize_row = 6,
fontsize_col = 6,
fontsize = 8,
main = "Hierarchical Clustering")
We will keep 70% of the data for training and 30% for testing. Train and test data partitions will contain the same class representation distribution as the whole dataset - stratified data splitting.
library(caret)
## Loading required package: lattice
# Stratified 70/30 train/test split (class proportions preserved in both sets)
set.seed(42)
trainIndex <- createDataPartition(y, p = 0.7, list = FALSE)
X_train <- X[trainIndex, ]
X_test  <- X[-trainIndex, ]
y_train <- factor(y[trainIndex])
y_test  <- factor(y[-trainIndex])
We retain the top 50 most variable genes across samples, assuming they carry the most discriminative signal for drug response.
# Calculate standard deviation
# IMPORTANT: feature selection is redone on the TRAINING set only,
# to avoid leaking test-set information into the model.
gene_sd <- apply(X = X_train, MARGIN = 2, FUN = sd)
hist(gene_sd, xlab = "Genes standard deviation")
# Select genes
genes2keep <- names(sort(x = gene_sd, decreasing = TRUE)[1:50])
# Make new train and test
X_train <- X_train[, which(colnames(X_train) %in% genes2keep)]
X_test <- X_test[, which(colnames(X_test) %in% genes2keep)]
# Combine predictors and response into a single data frame
# Scale training data
X_train_scaled <- scale(X_train)
# Scale test data using training mean and sd
# (never refit the scaler on the test set)
X_test_scaled <- scale(X_test,
center = attr(X_train_scaled, "scaled:center"),
scale = attr(X_train_scaled, "scaled:scale"))
# Center and scale training data
train_df <- as.data.frame(X_train_scaled)
# Encode the response as 0/1 for glm(family = binomial): Sensitive = 1.
train_df$Response <- ifelse(y_train == "Sensitive", 1, 0)
# Fit logistic regression model
# With 50 predictors and 131 training samples (see 130 df below) the
# classes are perfectly separable, hence the warning and the huge,
# unstable coefficients in the summary.
glm_model <- glm(Response ~ ., data = train_df, family = binomial)
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
# Summary of the model
summary(glm_model)
##
## Call:
## glm(formula = Response ~ ., family = binomial, data = train_df)
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -4.796e+14 5.863e+06 -81788164 <2e-16 ***
## ENSG00000005381 4.943e+14 1.431e+07 34541824 <2e-16 ***
## ENSG00000012223 -1.562e+14 1.115e+07 -14005594 <2e-16 ***
## ENSG00000019582 4.328e+14 9.283e+06 46620602 <2e-16 ***
## ENSG00000026025 1.090e+14 1.039e+07 10491268 <2e-16 ***
## ENSG00000038427 -7.563e+14 1.216e+07 -62191172 <2e-16 ***
## ENSG00000044574 1.148e+15 1.467e+07 78261304 <2e-16 ***
## ENSG00000070756 -3.234e+14 9.725e+06 -33257854 <2e-16 ***
## ENSG00000075624 3.082e+13 1.699e+07 1813992 <2e-16 ***
## ENSG00000087086 4.497e+14 1.846e+07 24363918 <2e-16 ***
## ENSG00000090382 -5.238e+14 1.289e+07 -40629027 <2e-16 ***
## ENSG00000100448 1.807e+14 1.114e+07 16217054 <2e-16 ***
## ENSG00000111640 -2.269e+14 9.845e+06 -23041532 <2e-16 ***
## ENSG00000122862 -2.806e+14 1.678e+07 -16722925 <2e-16 ***
## ENSG00000124942 -4.481e+14 1.452e+07 -30855303 <2e-16 ***
## ENSG00000132475 3.536e+14 1.107e+07 31941831 <2e-16 ***
## ENSG00000133112 1.617e+14 1.306e+07 12379339 <2e-16 ***
## ENSG00000143546 -1.240e+14 2.804e+07 -4423523 <2e-16 ***
## ENSG00000163220 -4.292e+14 2.981e+07 -14399846 <2e-16 ***
## ENSG00000166710 -7.730e+14 1.198e+07 -64546950 <2e-16 ***
## ENSG00000167658 -8.232e+14 1.955e+07 -42098413 <2e-16 ***
## ENSG00000167996 -8.164e+14 1.449e+07 -56345373 <2e-16 ***
## ENSG00000169429 -2.410e+14 9.540e+06 -25264678 <2e-16 ***
## ENSG00000170345 1.168e+15 1.288e+07 90689674 <2e-16 ***
## ENSG00000172232 5.988e+14 1.502e+07 39878284 <2e-16 ***
## ENSG00000177606 -3.526e+14 1.498e+07 -23543542 <2e-16 ***
## ENSG00000179218 -1.049e+15 1.787e+07 -58702423 <2e-16 ***
## ENSG00000196205 3.405e+14 1.843e+07 18474357 <2e-16 ***
## ENSG00000196415 -4.236e+13 1.053e+07 -4021826 <2e-16 ***
## ENSG00000196924 8.202e+14 1.193e+07 68754522 <2e-16 ***
## ENSG00000197746 -2.452e+14 1.865e+07 -13146698 <2e-16 ***
## ENSG00000198034 -7.112e+14 1.740e+07 -40880480 <2e-16 ***
## ENSG00000198712 -1.372e+14 3.786e+07 -3623930 <2e-16 ***
## ENSG00000198727 3.281e+15 5.241e+07 62605201 <2e-16 ***
## ENSG00000198763 -3.599e+14 4.080e+07 -8820718 <2e-16 ***
## ENSG00000198786 -1.017e+14 2.784e+07 -3652954 <2e-16 ***
## ENSG00000198804 -1.626e+14 3.967e+07 -4099091 <2e-16 ***
## ENSG00000198840 -2.025e+15 3.088e+07 -65573372 <2e-16 ***
## ENSG00000198886 -3.656e+15 8.695e+07 -42052274 <2e-16 ***
## ENSG00000198888 1.836e+15 3.537e+07 51910031 <2e-16 ***
## ENSG00000198899 -1.471e+15 5.255e+07 -27988285 <2e-16 ***
## ENSG00000198938 -1.439e+15 3.691e+07 -38999675 <2e-16 ***
## ENSG00000210082 -2.659e+14 1.609e+07 -16527021 <2e-16 ***
## ENSG00000211459 5.903e+14 1.548e+07 38136941 <2e-16 ***
## ENSG00000212907 3.231e+15 6.959e+07 46430736 <2e-16 ***
## ENSG00000228253 6.827e+14 4.873e+07 14009163 <2e-16 ***
## ENSG00000229807 7.753e+14 1.198e+07 64704940 <2e-16 ***
## ENSG00000234745 -7.102e+14 1.368e+07 -51912374 <2e-16 ***
## ENSG00000244734 1.509e+14 7.756e+06 19450682 <2e-16 ***
## ENSG00000248527 2.150e+14 3.074e+07 6993778 <2e-16 ***
## ENSG00000251562 -7.520e+14 9.387e+06 -80108333 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 181.6 on 130 degrees of freedom
## Residual deviance: 1585.9 on 80 degrees of freedom
## AIC: 1687.9
##
## Number of Fisher Scoring iterations: 14
# Predicted P(Sensitive) for the held-out samples
test_df <- as.data.frame(X_test_scaled)
glm_probs <- predict(glm_model, newdata = test_df, type = "response")
# Classify with a 0.5 probability cut-off
glm_labels <- ifelse(glm_probs > 0.5, "Sensitive", "Resistant")
glm_preds <- factor(glm_labels, levels = levels(y_test))
# Confusion matrix and derived statistics on the test set
confusionMatrix(glm_preds, y_test)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 20 12
## Sensitive 8 15
##
## Accuracy : 0.6364
## 95% CI : (0.4956, 0.7619)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 0.03916
##
## Kappa : 0.2706
##
## Mcnemar's Test P-Value : 0.50233
##
## Sensitivity : 0.7143
## Specificity : 0.5556
## Pos Pred Value : 0.6250
## Neg Pred Value : 0.6522
## Prevalence : 0.5091
## Detection Rate : 0.3636
## Detection Prevalence : 0.5818
## Balanced Accuracy : 0.6349
##
## 'Positive' Class : Resistant
##
Compute the ROC curve manually
# Ground truth (0 = Resistant, 1 = Sensitive)
actual <- ifelse(y_test == "Sensitive", 1, 0)
# Predicted probabilities for "Sensitive" class
probs <- glm_probs
# Define thresholds. The extra value above 1 guarantees the curve reaches
# (0, 0): preds uses ">=", so a threshold of exactly 1 would still call
# samples with predicted probability 1 positive (glm can produce
# probabilities numerically equal to 1).
thresholds <- c(seq(from = 0, to = 1, by = 0.01), 1.01)
# Initialize TPR and FPR vectors
tpr <- rep(x = NA_real_, length(thresholds))
fpr <- tpr
# Loop through thresholds and tabulate the 2x2 confusion counts
for (i in seq_along(thresholds)) {
  thresh <- thresholds[i]
  preds <- ifelse(probs >= thresh, 1, 0)
  TP <- sum(preds == 1 & actual == 1)
  TN <- sum(preds == 0 & actual == 0)
  FP <- sum(preds == 1 & actual == 0)
  FN <- sum(preds == 0 & actual == 1)
  tpr[i] <- TP / (TP + FN)  # sensitivity
  fpr[i] <- FP / (FP + TN)  # 1 - specificity
}
# Compute AUC using the trapezoidal rule
# Sort by increasing FPR, breaking ties by TPR so the curve is monotone
ord <- order(fpr, tpr)
fpr_sorted <- fpr[ord]
tpr_sorted <- tpr[ord]
auc <- sum(diff(fpr_sorted) * (head(tpr_sorted, -1) + tail(tpr_sorted, -1)) / 2)
auc
## [1] 0.6349206
# Plot ROC curve (use the sorted points so the line is drawn left to right)
plot(x = fpr_sorted, y = tpr_sorted, type = "l", col = "blue", lwd = 2,
     xlab = "False Positive Rate (1 - Specificity)",
     ylab = "True Positive Rate (Sensitivity)",
     main = paste("ROC Curve (AUC =", round(auc, 3), ")"))
abline(0, 1, col = "gray", lty = 2)  # chance diagonal
# Get coefficients (excluding intercept)
coefs <- coef(glm_model)
coefs <- coefs[-1] # Remove intercept
coefs <- sort(coefs, decreasing = TRUE) # Sort by value
# Create a data.frame with absolute values for plotting
imp_df <- data.frame(
Feature = names(coefs),
Coefficient = coefs,
Importance = abs(coefs)
)
# Take top N important features
top_n <- 20
imp_top <- head(imp_df[order(-imp_df$Importance), ], top_n)
library(ggplot2)
# Bars show signed coefficients ranked by absolute magnitude;
# fill encodes the sign (direction of association with "Sensitive").
ggplot(imp_top, aes(x = reorder(Feature, Importance), y = Coefficient,
fill = Coefficient > 0)) +
geom_col(show.legend = FALSE) +
coord_flip() +
labs(title = "Top 20 Important Features (Logistic Regression)",
x = "Feature",
y = "Coefficient") +
scale_fill_manual(values = c("steelblue", "firebrick")) +
theme_minimal(base_size = 14)
library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-8
set.seed(42)
# Step 1: Prepare data
# glmnet expects a numeric 0/1 response for family = "binomial": Sensitive = 1.
y_vec <- ifelse(y_train == "Sensitive", 1, 0)
# Step 2: Define folds manually to perform 5 fold cross validation
# NOTE(review): random assignment does not guarantee equal fold sizes or
# class balance within folds (acceptable for this tutorial).
folds <- sample(x = 1:5, size = nrow(X_train), replace = TRUE)
# Step 3: Define lambda grid manually
lambda_grid <- 10^seq(2, -4, length = 100) # from 100 to 0.0001
# Step 4: Storage for results
# Rows = lambdas, columns = folds.
cv_results <- matrix(NA, nrow = length(lambda_grid), ncol = 5)
# Step 5: Manual 5-fold CV
for (i in seq_len(length.out = max(folds))) {
cat("Processing Fold", i, "\n")
# Split into train/val
val_idx <- which(folds == i)
X_train_fold <- X_train[-val_idx, ]
y_train_fold <- y_vec[-val_idx]
X_val_fold <- X_train[val_idx, ]
y_val_fold <- y_vec[val_idx]
# Train glmnet model on training fold (all lambdas at once)
# Supplying the full lambda vector fits the whole regularization path in
# one call; prediction columns below line up with lambda_grid.
fold_model <- glmnet(
x = X_train_fold,
y = y_train_fold,
family = "binomial",
alpha = 1, # You can tune alpha separately too
lambda = lambda_grid
)
# Predict on validation fold
preds <- predict(fold_model, newx = X_val_fold, type = "response")
# preds: matrix of n_val_samples x n_lambda
# Now for each lambda, calculate accuracy (ACC)
for (j in seq_along(lambda_grid)) {
pred_prob <- preds[, j]
# Compute simple accuracy or AUC
pred_class <- ifelse(pred_prob > 0.5, 1, 0)
acc <- mean(pred_class == y_val_fold)
cv_results[j, i] <- acc
}
}
## Processing Fold 1
## Processing Fold 2
## Processing Fold 3
## Processing Fold 4
## Processing Fold 5
# Attach fold labels and the lambda values to the CV accuracy matrix
colnames(cv_results) <- paste0("Accuracy_fold", 1:5)
cv_results <- cbind(Lambda = lambda_grid, cv_results)
head(cv_results)
## Lambda Accuracy_fold1 Accuracy_fold2 Accuracy_fold3 Accuracy_fold4
## [1,] 100.00000 0.3666667 0.4857143 0.5333333 0.36
## [2,] 86.97490 0.3666667 0.4857143 0.5333333 0.36
## [3,] 75.64633 0.3666667 0.4857143 0.5333333 0.36
## [4,] 65.79332 0.3666667 0.4857143 0.5333333 0.36
## [5,] 57.22368 0.3666667 0.4857143 0.5333333 0.36
## [6,] 49.77024 0.3666667 0.4857143 0.5333333 0.36
## Accuracy_fold5
## [1,] 0.4615385
## [2,] 0.4615385
## [3,] 0.4615385
## [4,] 0.4615385
## [5,] 0.4615385
## [6,] 0.4615385
# Step 6: Mean CV accuracy per lambda (drop the Lambda column first)
mean_cv_accuracy <- rowMeans(cv_results[, -1])
# Step 7: Pick the lambda with the highest mean CV accuracy
best_lambda <- lambda_grid[which.max(mean_cv_accuracy)]
cat("Best lambda:", best_lambda, "\n")
## Best lambda: 0.01321941
# Visualization
df_lambda_cv <- data.frame(
Lambda = lambda_grid,
Accuracy = mean_cv_accuracy
)
# Accuracy profile over the regularization path; the dashed line marks
# the selected lambda.
ggplot(df_lambda_cv, aes(x = log(Lambda), y = Accuracy)) +
geom_line() +
geom_point() +
geom_vline(xintercept = log(best_lambda), color = "red", linetype = "dashed") +
labs(
title = "Manual CV: Accuracy vs log(Lambda)",
x = "log(Lambda)",
y = "CV Accuracy"
) +
theme_minimal()
# Step 8: Retrain final model on full training set
# Refit once at the selected lambda using all training samples.
final_model_manual <- glmnet(
x = X_train,
y = y_vec,
family = "binomial",
alpha = 1,
lambda = best_lambda
)
# Step 9: Evaluate on test set
X_test_mat <- as.matrix(X_test)
probs_test <- predict(final_model_manual, newx = X_test_mat, type = "response")
# 0.5 cut-off on P(Sensitive); as.factor() levels default to alphabetical
# order (Resistant, Sensitive), matching y_test.
pred_classes_test <- ifelse(probs_test > 0.5, "Sensitive", "Resistant") %>% as.factor()
confusionMatrix(pred_classes_test, y_test)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 23 4
## Sensitive 5 23
##
## Accuracy : 0.8364
## 95% CI : (0.712, 0.9223)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 4.245e-07
##
## Kappa : 0.6728
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.8214
## Specificity : 0.8519
## Pos Pred Value : 0.8519
## Neg Pred Value : 0.8214
## Prevalence : 0.5091
## Detection Rate : 0.4182
## Detection Prevalence : 0.4909
## Balanced Accuracy : 0.8366
##
## 'Positive' Class : Resistant
##
# Compute ROC curve and AUC
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
# probs_test is an n x 1 matrix; take its single column.
roc_obj <- roc(y_test, probs_test[,1])
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
auc_val <- auc(roc_obj)
# Plot ROC
plot(roc_obj, col = "#2c3e50", lwd = 2, main = paste("ROC Curve (AUC =", round(auc_val, 3), ")"))
# Coefficients of the final penalized model, tidied into a data frame
# (rownames of the coefficient matrix carry the gene ids)
coef_mat <- as.matrix(coef(final_model_manual))
coef_df <- data.frame(coefficient = coef_mat[, 1],
                      gene = rownames(coef_mat))
# Keep genes with non-zero coefficients; drop the intercept
keep <- coef_df$coefficient != 0 & coef_df$gene != "(Intercept)"
coef_df <- coef_df[keep, ]
# Sort so the largest-magnitude effects come first
coef_df <- coef_df[order(abs(coef_df$coefficient), decreasing = TRUE), ]
# Load ggplot2 for visualization
library(ggplot2)
# Horizontal bar chart of the selected genes, colored by coefficient sign
ggplot(coef_df,
       aes(x = reorder(gene, coefficient),
           y = coefficient, fill = coefficient > 0)) +
  geom_bar(stat = "identity", show.legend = FALSE) +
  coord_flip() +
  labs(title = "Non-Zero Coefficients from Elastic Net Model",
       x = "Gene",
       y = "Coefficient") +
  scale_fill_manual(values = c("firebrick", "steelblue")) +
  theme_minimal(base_size = 14)
# Train elastic net with 5-fold CV.
# (Removed a stray `cv.glmnet` token that preceded this comment: on its
# own it evaluates the function object and prints its source at render time.)
set.seed(42)
cv_fit <- cv.glmnet(x = X_train,
                    y = y_train,
                    alpha = 1, # Elastic net: mix between LASSO (1) and Ridge (0)
                    family = "binomial",
                    type.measure = "auc", # AUC for classification
                    nfolds = 5)
# View optimal lambda
cv_fit$lambda.min
## [1] 0.01716512
# Plot CV resamples
plot(cv_fit)
# Predict probabilities on test set
prob_test <- predict(cv_fit, newx = X_test, s = "lambda.min", type = "response")
# Binary prediction
pred_test <- predict(cv_fit, newx = X_test, s = "lambda.min", type = "class")
# or
# pred_test <- ifelse(prob_test > 0.5, 1, 0)
# Confusion matrix
table(Predicted = pred_test, Actual = y_test)
## Actual
## Predicted Resistant Sensitive
## Resistant 23 4
## Sensitive 5 23
# ROC/AUC
library(pROC)
roc_obj <- roc(y_test, as.numeric(prob_test))
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
auc(roc_obj)
## Area under the curve: 0.8757
# Plot ROC
plot(roc_obj, main = paste("Elastic Net AUC:", round(auc(roc_obj), 3)))
# Extract non-zero coefficients at optimal lambda
coef_enet <- coef(cv_fit, s = "lambda.min")
coef_df <- as.data.frame(as.matrix(coef_enet))
coef_df$gene <- rownames(coef_df)
colnames(coef_df)[1] <- "coefficient"
# Keep only non-zero and non-intercept
coef_df <- coef_df[coef_df$coefficient != 0 & coef_df$gene != "(Intercept)", ]
# Sort by magnitude
coef_df <- coef_df[order(abs(coef_df$coefficient), decreasing = TRUE), ]
# View top features
head(coef_df, 10)
## coefficient gene
## ENSG00000234745 -0.0009638116 ENSG00000234745
## ENSG00000196924 0.0005713910 ENSG00000196924
## ENSG00000197746 -0.0004652504 ENSG00000197746
## ENSG00000211459 0.0003037587 ENSG00000211459
## ENSG00000132475 0.0002563555 ENSG00000132475
## ENSG00000167658 -0.0002423296 ENSG00000167658
## ENSG00000143546 -0.0002046767 ENSG00000143546
## ENSG00000170345 0.0002045683 ENSG00000170345
## ENSG00000169429 -0.0001809248 ENSG00000169429
## ENSG00000038427 -0.0001759602 ENSG00000038427
# Bar chart of the genes retained by the penalized model,
# colored by coefficient sign.
ggplot(coef_df, aes(x = reorder(gene, coefficient), y = coefficient, fill = coefficient > 0)) +
geom_bar(stat = "identity", show.legend = FALSE) +
coord_flip() +
labs(title = "Non-Zero Coefficients from Elastic Net",
x = "Gene",
y = "Coefficient") +
scale_fill_manual(values = c("firebrick", "steelblue")) +
theme_minimal(base_size = 14)
We selected a diverse set of models that represent different ML families:
Elastic Net (linear model with regularization)
KNN (non-parametric, distance-based)
Random Forest (ensemble of decision trees)
GBM (boosted trees)
SVM (Radial) (non-linear classifier for complex boundaries)
library(doParallel)
## Loading required package: foreach
## Loading required package: iterators
## Loading required package: parallel
# Start a 4-worker cluster so caret can resample in parallel
cl <- makePSOCKcluster(4)
registerDoParallel(cl)
# Define training control and pre-processing
ctrl <- trainControl(
  method = "cv",
  number = 5, # 5-fold cross-validation
  classProbs = TRUE, # class probabilities needed for ROC-based selection
  summaryFunction = twoClassSummary # reports ROC, Sens, Spec
)
# Center and scale
preproc <- c("center", "scale")
# Train models
models2train <- c(
  "glmnet", # Elastic Net (glmnet)
  "knn", # KNN
  "rf", # Random forest
  "gbm", # Gradient Boosted Machines
  "xgbTree", # XGboost
  "svmRadial" # Support vector machines with radial kernel
)
all_models <- vector(mode = "list", length = length(models2train))
# NOTE: the loop variable is kept distinct from the fitted model; the
# original loop assigned train()'s result to the loop variable itself,
# shadowing the method name inside the loop body.
for (i in seq_along(models2train)) {
  method_i <- models2train[i]
  cat("Training model:", method_i, "\n")
  all_models[[i]] <- train(
    x = X_train,
    y = y_train,
    method = method_i,
    trControl = ctrl,
    preProcess = preproc,
    tuneLength = 3, # try 3 values per tuning parameter
    metric = "ROC"
  )
}
## Training model: glmnet
## Training model: knn
## Training model: rf
## Training model: gbm
## Iter TrainDeviance ValidDeviance StepSize Improve
## 1 1.3328 nan 0.1000 0.0226
## 2 1.2923 nan 0.1000 0.0114
## 3 1.2567 nan 0.1000 0.0100
## 4 1.2271 nan 0.1000 0.0118
## 5 1.1912 nan 0.1000 0.0059
## 6 1.1611 nan 0.1000 0.0069
## 7 1.1383 nan 0.1000 0.0065
## 8 1.1252 nan 0.1000 -0.0007
## 9 1.1084 nan 0.1000 0.0023
## 10 1.0921 nan 0.1000 0.0035
## 20 0.9745 nan 0.1000 -0.0010
## 40 0.8436 nan 0.1000 0.0020
## 50 0.7945 nan 0.1000 -0.0020
##
## Training model: xgbTree
## Training model: svmRadial
# Label the fitted models and release the parallel workers
names(all_models) <- models2train
stopCluster(cl)
# One tuning-profile plot per model
for (model_name in names(all_models)) {
  trellis.par.set(caretTheme())
  print(plot(all_models[[model_name]], main = model_name))
}
# Collect the cross-validation resamples from all models for comparison
resamps <- resamples(all_models)
summary(resamps)
##
## Call:
## summary.resamples(object = resamps)
##
## Models: glmnet, knn, rf, gbm, xgbTree, svmRadial
## Number of resamples: 5
##
## ROC
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## glmnet 0.6593407 0.7514793 0.7810651 0.7756551 0.8343195 0.8520710 0
## knn 0.6035503 0.6775148 0.7455621 0.7658918 0.8324176 0.9704142 0
## rf 0.5857988 0.7455621 0.7692308 0.7676669 0.8324176 0.9053254 0
## gbm 0.4911243 0.6923077 0.7802198 0.7465765 0.8757396 0.8934911 0
## xgbTree 0.5739645 0.7041420 0.7527473 0.7138631 0.7573964 0.7810651 0
## svmRadial 0.5917160 0.7988166 0.8047337 0.7893491 0.8284024 0.9230769 0
##
## Sens
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## glmnet 0.5384615 0.6923077 0.7692308 0.7252747 0.7692308 0.8571429 0
## knn 0.5384615 0.6923077 0.6923077 0.7560440 0.8571429 1.0000000 0
## rf 0.6923077 0.6923077 0.6923077 0.7120879 0.7142857 0.7692308 0
## gbm 0.6153846 0.6153846 0.6923077 0.6802198 0.6923077 0.7857143 0
## xgbTree 0.3846154 0.6153846 0.6153846 0.6329670 0.6923077 0.8571429 0
## svmRadial 0.5384615 0.6153846 0.6923077 0.6802198 0.7692308 0.7857143 0
##
## Spec
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## glmnet 0.4615385 0.6153846 0.7692308 0.7076923 0.8461538 0.8461538 0
## knn 0.5384615 0.5384615 0.6923077 0.6615385 0.6923077 0.8461538 0
## rf 0.3846154 0.6153846 0.6923077 0.6461538 0.6923077 0.8461538 0
## gbm 0.4615385 0.6153846 0.6923077 0.6923077 0.8461538 0.8461538 0
## xgbTree 0.6923077 0.6923077 0.6923077 0.7230769 0.7692308 0.7692308 0
## svmRadial 0.3846154 0.7692308 0.7692308 0.7384615 0.8461538 0.9230769 0
# Customize the lattice theme used by the resample plots
theme1 <- trellis.par.get()
theme1$plot.symbol$col = rgb(.2, .2, .2, .4)
theme1$plot.symbol$pch = 16
theme1$plot.line$col = rgb(1, 0, 0, .7)
theme1$plot.line$lwd <- 2
trellis.par.set(theme1)
# Box-and-whisker plots of ROC/Sens/Spec across the CV resamples
bwplot(resamps, layout = c(3, 1))
# Evaluate statistical significance of differences
trellis.par.set(caretTheme())
dotplot(resamps, metric = "ROC")
# Pairwise differences between models on each resampled metric
difValues <- diff(resamps)
difValues
##
## Call:
## diff.resamples(x = resamps)
##
## Models: glmnet, knn, rf, gbm, xgbTree, svmRadial
## Metrics: ROC, Sens, Spec
## Number of differences: 15
## p-value adjustment: bonferroni
summary(difValues)
##
## Call:
## summary.diff.resamples(object = difValues)
##
## p-value adjustment: bonferroni
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
##
## ROC
## glmnet knn rf gbm xgbTree svmRadial
## glmnet 0.009763 0.007988 0.029079 0.061792 -0.013694
## knn 1 -0.001775 0.019315 0.052029 -0.023457
## rf 1 1 0.021090 0.053804 -0.021682
## gbm 1 1 1 0.032713 -0.042773
## xgbTree 1 1 1 1 -0.075486
## svmRadial 1 1 1 1 1
##
## Sens
## glmnet knn rf gbm xgbTree svmRadial
## glmnet -0.03077 0.01319 0.04505 0.09231 0.04505
## knn 1 0.04396 0.07582 0.12308 0.07582
## rf 1 1 0.03187 0.07912 0.03187
## gbm 1 1 1 0.04725 0.00000
## xgbTree 1 1 1 1 -0.04725
## svmRadial 1 1 1 1 1
##
## Spec
## glmnet knn rf gbm xgbTree svmRadial
## glmnet 0.04615 0.06154 0.01538 -0.01538 -0.03077
## knn 1 0.01538 -0.03077 -0.06154 -0.07692
## rf 1 1 -0.04615 -0.07692 -0.09231
## gbm 1 1 1 -0.03077 -0.04615
## xgbTree 1 1 1 1 -0.01538
## svmRadial 1 1 1 1 1
# Box-and-whisker plots of the pairwise differences, one panel per metric
trellis.par.set(theme1)
bwplot(difValues, layout = c(3, 1))
We will use confusion matrices, classification reports and the Area Under the ROC curve to evaluate the performance of our models on the test set.
# Test-set confusion matrix for every trained model.
# seq_along() rather than 1:length(): an empty model list then yields
# zero iterations instead of iterating over c(1, 0).
for (i in seq_along(all_models)) {
  preds <- predict(all_models[[i]], X_test)
  cat("\n##############################")
  cat("\nModel:", names(all_models)[i], "\n")
  print(confusionMatrix(preds, y_test))
}
##
## ##############################
## Model: glmnet
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 24 3
## Sensitive 4 24
##
## Accuracy : 0.8727
## 95% CI : (0.7552, 0.9473)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 1.375e-08
##
## Kappa : 0.7455
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.8571
## Specificity : 0.8889
## Pos Pred Value : 0.8889
## Neg Pred Value : 0.8571
## Prevalence : 0.5091
## Detection Rate : 0.4364
## Detection Prevalence : 0.4909
## Balanced Accuracy : 0.8730
##
## 'Positive' Class : Resistant
##
##
## ##############################
## Model: knn
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 24 6
## Sensitive 4 21
##
## Accuracy : 0.8182
## 95% CI : (0.691, 0.9092)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 1.945e-06
##
## Kappa : 0.6358
##
## Mcnemar's Test P-Value : 0.7518
##
## Sensitivity : 0.8571
## Specificity : 0.7778
## Pos Pred Value : 0.8000
## Neg Pred Value : 0.8400
## Prevalence : 0.5091
## Detection Rate : 0.4364
## Detection Prevalence : 0.5455
## Balanced Accuracy : 0.8175
##
## 'Positive' Class : Resistant
##
##
## ##############################
## Model: rf
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 24 6
## Sensitive 4 21
##
## Accuracy : 0.8182
## 95% CI : (0.691, 0.9092)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 1.945e-06
##
## Kappa : 0.6358
##
## Mcnemar's Test P-Value : 0.7518
##
## Sensitivity : 0.8571
## Specificity : 0.7778
## Pos Pred Value : 0.8000
## Neg Pred Value : 0.8400
## Prevalence : 0.5091
## Detection Rate : 0.4364
## Detection Prevalence : 0.5455
## Balanced Accuracy : 0.8175
##
## 'Positive' Class : Resistant
##
##
## ##############################
## Model: gbm
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 25 6
## Sensitive 3 21
##
## Accuracy : 0.8364
## 95% CI : (0.712, 0.9223)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 4.245e-07
##
## Kappa : 0.672
##
## Mcnemar's Test P-Value : 0.505
##
## Sensitivity : 0.8929
## Specificity : 0.7778
## Pos Pred Value : 0.8065
## Neg Pred Value : 0.8750
## Prevalence : 0.5091
## Detection Rate : 0.4545
## Detection Prevalence : 0.5636
## Balanced Accuracy : 0.8353
##
## 'Positive' Class : Resistant
##
##
## ##############################
## Model: xgbTree
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 21 5
## Sensitive 7 22
##
## Accuracy : 0.7818
## 95% CI : (0.6499, 0.8819)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 2.915e-05
##
## Kappa : 0.5641
##
## Mcnemar's Test P-Value : 0.7728
##
## Sensitivity : 0.7500
## Specificity : 0.8148
## Pos Pred Value : 0.8077
## Neg Pred Value : 0.7586
## Prevalence : 0.5091
## Detection Rate : 0.3818
## Detection Prevalence : 0.4727
## Balanced Accuracy : 0.7824
##
## 'Positive' Class : Resistant
##
##
## ##############################
## Model: svmRadial
## Confusion Matrix and Statistics
##
## Reference
## Prediction Resistant Sensitive
## Resistant 21 3
## Sensitive 7 24
##
## Accuracy : 0.8182
## 95% CI : (0.691, 0.9092)
## No Information Rate : 0.5091
## P-Value [Acc > NIR] : 1.945e-06
##
## Kappa : 0.6372
##
## Mcnemar's Test P-Value : 0.3428
##
## Sensitivity : 0.7500
## Specificity : 0.8889
## Pos Pred Value : 0.8750
## Neg Pred Value : 0.7742
## Prevalence : 0.5091
## Detection Rate : 0.3818
## Detection Prevalence : 0.4364
## Balanced Accuracy : 0.8194
##
## 'Positive' Class : Resistant
##
library(ggplot2)
library(pROC)
# Calculate ROCs on test data, one curve per model.
# (Removed a dead `auc_val <- auc(roc_obj)` inside the function: the value
# was never used there; AUCs are computed from these objects afterwards.)
roc_list <- lapply(names(all_models), function(name) {
  probs <- predict(all_models[[name]], X_test, type = "prob")[, "Sensitive"]
  roc(y_test, probs)
})
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
# Holdout AUC per model; vapply() keeps the result a plain numeric vector
# regardless of input length (sapply() can silently change return type)
aucs_holdout <- vapply(roc_list, function(r) as.numeric(auc(r)), numeric(1))
# Label each curve with its model name and holdout AUC
# (roc_list was built in the order of models2train)
names(roc_list) <- paste0(models2train, " (AUC = ", round(aucs_holdout, 3), ")")
ggroc(roc_list) + theme_minimal()
# Extract variable importances
library(gbm)
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
varimps_list <- lapply(all_models, function(modeli) {
vi <- varImp(modeli)$importance
vi$Feature <- rownames(vi)
return(vi)
})
# Name each model
names(varimps_list) <- models2train
# Merge all into a long format
vi_long <- rbindlist(
lapply(names(varimps_list), function(namei) {
dt <- as.data.table(varimps_list[[namei]])
dt[, Model := namei]
return(dt)
}),
use.names = TRUE, fill = TRUE
)
# For models that have multiple classes (e.g., "Sensitive", "Resistant"),
# take the average importance across classes if necessary
# Models reporting a single "Overall" column keep it as-is; rows where
# Overall is NA (per-class reporters, padded by rbindlist fill = TRUE)
# fall back to the "Resistant" column instead.
vi_long[, my_overall := ifelse(test = is.na(Overall),
                               yes = Resistant,
                               no = Overall)
]
# Cast to wide: one row per model, one column per feature.
vi_long_melted <- dcast.data.table(data = vi_long, formula = Model ~ Feature,
                                   value.var = "my_overall")
# Matrix form for plotting; model names become the rownames.
vi_long_melted_mat <- as.matrix(vi_long_melted[, -"Model"])
rownames(vi_long_melted_mat) <- vi_long_melted[, Model]
# Rank features by their mean importance across all models (descending).
mean_imp_all_models <- sort(
  x = apply(
    X = vi_long_melted_mat,
    MARGIN = 2,
    FUN = mean
  ),
  decreasing = TRUE)
# Plot
number_of_genes2plot <- 20
# Keep only the top-ranked features. Note: %in% subsetting preserves the
# matrix's original column order, not the importance ranking.
vi2plot_mat <- vi_long_melted_mat[,
  colnames(vi_long_melted_mat) %in% names(mean_imp_all_models)[1:number_of_genes2plot]]
library(pheatmap)
# Rows of vi2plot_mat are models (rownames set above), columns are features.
pheatmap(
  mat = vi2plot_mat,
  cluster_rows = TRUE, # Cluster models (rows)
  cluster_cols = TRUE, # Cluster features (columns)
  scale = "none", # Do not scale the data (optional: could use "row" or "column" if needed)
  fontsize_row = 8,
  fontsize_col = 10,
  treeheight_row = 50,
  treeheight_col = 50,
  main = "Feature Importance Across Models"
)
# Random-forest importances alone, as a dot plot of the top 20 features.
plot(varImp(all_models$rf), top = 20,
     main = "Variable importance")
H2O AutoML can take time depending on dataset size. For tutorial purposes, we limit the number of models to 5. AutoML (Automated Machine Learning) automates the process of training and tuning multiple models, including ensembles, to find the best-performing one with minimal manual effort.
library(h2o)
##
## ----------------------------------------------------------------------
##
## Your next step is to start H2O:
## > h2o.init()
##
## For H2O package documentation, ask for help:
## > ??h2o
##
## After starting H2O, you can use the Web UI at http://localhost:54321
## For more information visit https://docs.h2o.ai
##
## ----------------------------------------------------------------------
##
## Attaching package: 'h2o'
## The following object is masked from 'package:pROC':
##
## var
## The following objects are masked from 'package:data.table':
##
## hour, month, week, year
## The following objects are masked from 'package:stats':
##
## cor, sd, var
## The following objects are masked from 'package:base':
##
## &&, %*%, %in%, ||, apply, as.factor, as.numeric, colnames,
## colnames<-, ifelse, is.character, is.factor, is.numeric, log,
## log10, log1p, log2, round, signif, trunc
# Start (or connect to) a local H2O cluster; required before any h2o.* call.
h2o.init()
##
## H2O is not running yet, starting it now...
##
## Note: In case of errors look at the following log files:
## /var/folders/px/vgzvyptn3lvbx79g9smcm47c0000gn/T//RtmpsSsgZ8/file89e30d29e3a/h2o_nestoraskarathanasis_started_from_r.out
## /var/folders/px/vgzvyptn3lvbx79g9smcm47c0000gn/T//RtmpsSsgZ8/file89e53920ae0/h2o_nestoraskarathanasis_started_from_r.err
##
##
## Starting H2O JVM and connecting: .. Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 2 seconds 634 milliseconds
## H2O cluster timezone: Asia/Nicosia
## H2O data parsing timezone: UTC
## H2O cluster version: 3.44.0.3
## H2O cluster version age: 1 year, 4 months and 23 days
## H2O cluster name: H2O_started_from_R_nestoraskarathanasis_tjb904
## H2O cluster total nodes: 1
## H2O cluster total memory: 4.00 GB
## H2O cluster total cores: 8
## H2O cluster allowed cores: 8
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## H2O Internal Security: FALSE
## R Version: R version 4.4.2 (2024-10-31)
## Warning in h2o.clusterInfo():
## Your H2O cluster version is (1 year, 4 months and 23 days) old. There may be a newer version available.
## Please download and install the latest version from: https://h2o-release.s3.amazonaws.com/h2o/latest_stable.html
# Convert the train/test matrices to H2OFrames, binding the response on as
# a factor column named "auc_binary" so H2O treats this as binomial
# classification.
train_h2o <- as.h2o(data.frame(X_train, auc_binary = as.factor(y_train)))
## Warning in use.package("data.table"): data.table cannot be used without R
## package bit64 version 0.9.7 or higher. Please upgrade to take advangage of
## data.table speedups.
## | | | 0% | |======================================================================| 100%
test_h2o <- as.h2o(data.frame(X_test, auc_binary = as.factor(y_test)))
## Warning in use.package("data.table"): data.table cannot be used without R
## package bit64 version 0.9.7 or higher. Please upgrade to take advangage of
## data.table speedups.
## | | | 0% | |======================================================================| 100%
# Run AutoML on the training frame only. max_models = 5 keeps the tutorial
# fast (stacked ensembles are added on top of the base models, hence the
# 7-row leaderboard below); the seed makes the run reproducible.
aml <- h2o.automl(
  x = colnames(X_train),
  y = "auc_binary",
  training_frame = train_h2o,
  max_models = 5,
  seed = 42
)
## | | | 0% | |==== | 6%
## 17:14:21.665: AutoML: XGBoost is not available; skipping it.
## 17:14:22.75: _min_rows param, The dataset size is too small to split for min_rows=100.0: must have at least 200.0 (weighted) rows, but have only 131.0. | |======================================================================| 100%
# Leaderboard: all AutoML models ranked by their cross-validated metrics.
lb <- aml@leaderboard
print(lb)
## model_id auc logloss
## 1 GLM_1_AutoML_1_20250513_171421 0.7776224 0.5699923
## 2 DRF_1_AutoML_1_20250513_171421 0.7710956 0.5748109
## 3 StackedEnsemble_BestOfFamily_1_AutoML_1_20250513_171421 0.7650350 0.5912695
## 4 StackedEnsemble_AllModels_1_AutoML_1_20250513_171421 0.7601399 0.5885814
## 5 GBM_2_AutoML_1_20250513_171421 0.7317016 0.6333007
## 6 GBM_4_AutoML_1_20250513_171421 0.7275058 0.6152420
## aucpr mean_per_class_error rmse mse
## 1 0.7790839 0.2516317 0.4390031 0.1927237
## 2 0.7567653 0.2969697 0.4414689 0.1948948
## 3 0.7416697 0.2589744 0.4478647 0.2005828
## 4 0.7457218 0.2440559 0.4468220 0.1996499
## 5 0.7412831 0.2671329 0.4637726 0.2150850
## 6 0.7586785 0.2900932 0.4586345 0.2103456
##
## [7 rows x 7 columns]
# To generate predictions on a test set, you can make predictions
# directly on the `H2OAutoML` object or on the leader model
# object directly
pred <- h2o.predict(aml, test_h2o) # predict(aml, test) also works
## | | | 0% | |======================================================================| 100%
pred
## predict Resistant Sensitive
## 1 Resistant 0.7923473 0.2076527
## 2 Sensitive 0.3015856 0.6984144
## 3 Resistant 0.8356658 0.1643342
## 4 Sensitive 0.1205052 0.8794948
## 5 Sensitive 0.3398128 0.6601872
## 6 Sensitive 0.2418082 0.7581918
##
## [55 rows x 3 columns]
h2o.performance(model = aml@leader, newdata = test_h2o)
## H2OBinomialMetrics: glm
##
## MSE: 0.1297629
## RMSE: 0.3602262
## LogLoss: 0.4250307
## Mean Per-Class Error: 0.1269841
## AUC: 0.8994709
## AUCPR: 0.8168489
## Gini: 0.7989418
## R^2: 0.4807766
## Residual Deviance: 46.75338
## AIC: 148.7534
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## Resistant Sensitive Error Rate
## Resistant 24 4 0.142857 =4/28
## Sensitive 3 24 0.111111 =3/27
## Totals 27 28 0.127273 =7/55
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.504956 0.872727 27
## 2 max f2 0.319040 0.902778 35
## 3 max f0point5 0.550677 0.877863 25
## 4 max accuracy 0.550677 0.872727 25
## 5 max precision 0.709413 0.928571 13
## 6 max recall 0.204469 1.000000 42
## 7 max specificity 0.917052 0.964286 0
## 8 max absolute_mcc 0.504956 0.746032 27
## 9 max min_per_class_accuracy 0.504956 0.857143 27
## 10 max mean_per_class_accuracy 0.504956 0.873016 27
## 11 max tns 0.917052 27.000000 0
## 12 max fns 0.917052 27.000000 0
## 13 max fps 0.020601 28.000000 54
## 14 max tps 0.204469 27.000000 42
## 15 max tnr 0.917052 0.964286 0
## 16 max fnr 0.917052 1.000000 0
## 17 max fpr 0.020601 1.000000 54
## 18 max tpr 0.204469 1.000000 42
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
Explain leader model & compare with all AutoML models
# Full explainability report for the whole AutoML run, computed on the
# hold-out frame: leaderboard, confusion matrices, learning curves,
# variable importance, model correlation, SHAP summary, and PDPs.
exa <- h2o.explain(aml, test_h2o)
exa
##
##
## Leaderboard
## ===========
##
## > Leaderboard shows models with their metrics. When provided with H2OAutoML object, the leaderboard shows 5-fold cross-validated metrics by default (depending on the H2OAutoML settings), otherwise it shows metrics computed on the newdata. At most 20 models are shown by default.
##
##
## | | model_id | auc | logloss | aucpr | mean_per_class_error | rmse | mse | training_time_ms | predict_time_per_row_ms | algo
## |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
## | **1** |GLM_1_AutoML_1_20250513_171421 | 0.777622377622378 | 0.569992322805128 | 0.779083925862759 | 0.251631701631702 | 0.439003050049449 | 0.192723677952719 | 24 | 0.010253 | GLM |
## | **2** |DRF_1_AutoML_1_20250513_171421 | 0.771095571095571 | 0.574810866462091 | 0.75676534427548 | 0.296969696969697 | 0.441468888358538 | 0.194894779388523 | 62 | 0.011994 | DRF |
## | **3** |StackedEnsemble_BestOfFamily_1_AutoML_1_20250513_171421 | 0.765034965034965 | 0.591269492015068 | 0.741669678006428 | 0.258974358974359 | 0.447864681629856 | 0.200582773051412 | 435 | 0.020573 | StackedEnsemble |
## | **4** |StackedEnsemble_AllModels_1_AutoML_1_20250513_171421 | 0.76013986013986 | 0.588581434289178 | 0.745721823652432 | 0.244055944055944 | 0.44682202373744 | 0.199649920896822 | 318 | 0.023425 | StackedEnsemble |
## | **5** |GBM_2_AutoML_1_20250513_171421 | 0.731701631701632 | 0.633300712627457 | 0.741283122632491 | 0.267132867132867 | 0.463772582366536 | 0.215085008154925 | 98 | 0.013958 | GBM |
## | **6** |GBM_4_AutoML_1_20250513_171421 | 0.727505827505828 | 0.615242036374204 | 0.758678528249585 | 0.29009324009324 | 0.45863448252923 | 0.210345588564854 | 52 | 0.009748 | GBM |
## | **7** |GBM_3_AutoML_1_20250513_171421 | 0.722377622377622 | 0.639476326162786 | 0.680150321289085 | 0.31958041958042 | 0.465295969451299 | 0.216500339187625 | 85 | 0.01019 | GBM |
##
##
## Confusion Matrix
## ================
##
## > Confusion matrix shows a predicted class vs an actual class.
##
##
##
## GLM_1_AutoML_1_20250513_171421
## ------------------------------
##
## | | Resistant | Sensitive | Error | Rate
## |:---:|:---:|:---:|:---:|:---:|
## | **Resistant** |24 | 4 | 0.142857142857143 | =4/28 |
## | **Sensitive** |3 | 24 | 0.111111111111111 | =3/27 |
## | **Totals** |27 | 28 | 0.127272727272727 | =7/55 |
##
##
## Learning Curve Plot
## ===================
##
## > Learning curve plot shows the loss function/metric dependent on number of iterations or trees for tree-based algorithms. This plot can be useful for determining whether the model overfits.
##
##
## Variable Importance
## ===================
##
## > The variable importance plot shows the relative importance of the most important variables in the model.
##
##
## Variable Importance Heatmap
## ===========================
##
## > Variable importance heatmap shows variable importance across multiple models. Some models in H2O return variable importance for one-hot (binary indicator) encoded versions of categorical columns (e.g. Deep Learning, XGBoost). In order for the variable importance of categorical columns to be compared across all model types we compute a summarization of the the variable importance across all one-hot encoded features and return a single variable importance for the original categorical feature. By default, the models and variables are ordered by their similarity.
##
##
## Model Correlation
## =================
##
## > This plot shows the correlation between the predictions of the models. For classification, frequency of identical predictions is used. By default, models are ordered by their similarity (as computed by hierarchical clustering).
## Interpretable models: GLM_1_AutoML_1_20250513_171421
##
##
## SHAP Summary
## ============
##
## > SHAP summary plot shows the contribution of the features for each instance (row of data). The sum of the feature contributions and the bias term is equal to the raw prediction of the model, i.e., prediction before applying inverse link function.
##
##
## Partial Dependence Plots
## ========================
##
## > Partial dependence plot (PDP) gives a graphical depiction of the marginal effect of a variable on the response. The effect of a variable is measured in change in the mean response. PDP assumes independence between the feature for which is the PDP computed and the rest.
Explain a single H2O model (e.g. leader model from AutoML)
# Get the leaderboard
lb <- aml@leaderboard
# Get the ID of the second model
# (lb is an H2OFrame, so pull the model_id column client-side first)
second_model_id <- as.data.frame(lb$model_id)[2, 1]
# Retrieve the model
model2explain <- h2o.getModel(second_model_id)
# Explain the model
# Same report as for the AutoML object, but for this single model only.
exm <- h2o.explain(model2explain, test_h2o)
exm
exm
##
##
## Confusion Matrix
## ================
##
## > Confusion matrix shows a predicted class vs an actual class.
##
##
##
## DRF_1_AutoML_1_20250513_171421
## ------------------------------
##
## | | Resistant | Sensitive | Error | Rate
## |:---:|:---:|:---:|:---:|:---:|
## | **Resistant** |19 | 9 | 0.321428571428571 | =9/28 |
## | **Sensitive** |1 | 26 | 0.037037037037037 | =1/27 |
## | **Totals** |20 | 35 | 0.181818181818182 | =10/55 |
##
##
## Learning Curve Plot
## ===================
##
## > Learning curve plot shows the loss function/metric dependent on number of iterations or trees for tree-based algorithms. This plot can be useful for determining whether the model overfits.
##
##
## Variable Importance
## ===================
##
## > The variable importance plot shows the relative importance of the most important variables in the model.
##
##
## SHAP Summary
## ============
##
## > SHAP summary plot shows the contribution of the features for each instance (row of data). The sum of the feature contributions and the bias term is equal to the raw prediction of the model, i.e., prediction before applying inverse link function.
##
##
## Partial Dependence Plots
## ========================
##
## > Partial dependence plot (PDP) gives a graphical depiction of the marginal effect of a variable on the response. The effect of a variable is measured in change in the mean response. PDP assumes independence between the feature for which is the PDP computed and the rest.
In this tutorial, we explored how to preprocess the BeatAML data, frame drug response as a classification problem, train and compare multiple models, evaluate them with ROC curves and variable-importance plots, and automate model selection with H2O AutoML.
Participants are encouraged to experiment with: other drugs, different response thresholds, and additional models available in caret and h2o.
Happy modeling!!